# 导入必要的库
import numpy as np  # 用于处理多维数组和矩阵运算
import logging  # 用于记录日志
from util import createXY  # 用于创建数据集的特征和标签
from sklearn.model_selection import train_test_split  # 用于拆分数据集为训练集和测试集
from sklearn.neighbors import KNeighborsClassifier  # sklearn中的K近邻分类器
import argparse  # 用于解析命令行参数
from tqdm import tqdm  # 用于在循环中显示进度条

# 配置logging, 确保能够打印正在运行的函数名
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取命令行参数
def get_args():
    parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。') # 创建命令行参数解析器对象
    parser.add_argument('-m', '--mode', type=str, required=True, choices=['cpu', 'gpu'], help='选择训练模式：CPU或GPU。')
    parser.add_argument('-f', '--feature', type=str, required=True, choices=['flat', 'vgg'], help='选择特征提取方法：flat或vgg。')
    parser.add_argument('-l', '--library', type=str, required=True, choices=['sklearn', 'faiss'], help='选择使用的库：sklearn或faiss。')
    args = parser.parse_args()
    return args

# 主函数，运行训练过程
def main():
    args = get_args()

    logging.info(f"选择特征提取方法是 {args.feature.upper()}")
    logging.info(f"选择使用的库是 {args.library.upper()}")

    # 载入和预处理数据
    try:
        X, y = createXY(train_folder="../data/train", dest_folder=".", method=args.feature)
    except ValueError as e:
        logging.error(f"数据加载错误: {e}")
        return
    except Exception as e:
        logging.error(f"未知错误: {e}")
        return

    if len(X) == 0:
        logging.error("未能加载任何图像数据，请检查数据路径和图像文件")
        return

    X = np.array(X).astype('float32')
    
    # 只有在使用faiss时才导入faiss并进行L2归一化
    if args.library == 'faiss':
        import faiss
        faiss.normalize_L2(X)  # 对数据进行L2归一化
    
    y = np.array(y)
    logging.info("数据加载和预处理完成。")

    # 数据集分割为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
    logging.info("数据集划分为训练集和测试集。")

    # 初始化变量，跟踪最佳的k值和相应的准确率
    best_k = -1
    best_accuracy = 0.0

    # 定义测试的k值范围
    k_values = range(1, 6)
    
    # 根据提供的库选择K近邻算法实现
    if args.library == 'faiss':
        from FaissKNeighbors import FaissKNeighbors
        KNNClass = FaissKNeighbors
        # 初始化FAISS所需的资源（仅FAISS模式下）
        res = None
        try:
            import faiss
            res = faiss.StandardGpuResources() if args.mode == 'gpu' else None
        except ImportError:
            logging.warning("无法导入faiss，将使用CPU模式")
            res = None
        logging.info(f"选择运行模式: {args.mode.upper()}")
    else:
        KNNClass = KNeighborsClassifier
        res = None
        logging.info("使用纯 sklearn 方案，忽略 --mode 参数")
    
    logging.info(f"使用的库为: {args.library.upper()}")

    # 遍历k值，训练并评估模型
    for k in tqdm(k_values, desc='寻找最佳k值'):
        if args.library == 'faiss':
            knn = KNNClass(k=k, res=res)
        else:
            knn = KNNClass(n_neighbors=k)
        knn.fit(X_train, y_train)
        accuracy = knn.score(X_test, y_test)
        
        # 更新最佳k值和准确率
        if accuracy > best_accuracy:
            best_k = k
            best_accuracy = accuracy

    # 打印结果
    logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')

# 如果是主脚本，则执行main函数
if __name__ == '__main__':
    main()